{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "06_pytorch_transfer_learning_exercises.ipynb",
      "provenance": [],
      "collapsed_sections": [],
      "authorship_tag": "ABX9TyNXgsMoZLpp/LR5qPnNG65Z",
      "include_colab_link": true
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    },
    "accelerator": "GPU",
    "widgets": {
      "application/vnd.jupyter.widget-state+json": {
        "6e25b4bb0d254191a793696a0f4f00ce": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HBoxModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HBoxModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HBoxView",
            "box_style": "",
            "children": [
              "IPY_MODEL_37424313e66f474da42cfe1b512f09df",
              "IPY_MODEL_58fd00f6a9114192a4fa757c1f669bff",
              "IPY_MODEL_f115ea4b5fad4bb1910fca49ed3da8a1"
            ],
            "layout": "IPY_MODEL_e8eba8e353e940ff9287929e41e4d656"
          }
        },
        "37424313e66f474da42cfe1b512f09df": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_bc33539914a947ee89c271f10ea6a2bb",
            "placeholder": "​",
            "style": "IPY_MODEL_6e03cb60fab94b7e92ce16c8178922dd",
            "value": "100%"
          }
        },
        "58fd00f6a9114192a4fa757c1f669bff": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "FloatProgressModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "FloatProgressModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "ProgressView",
            "bar_style": "success",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_5d464254c31d4516899643112fa0e958",
            "max": 21444401,
            "min": 0,
            "orientation": "horizontal",
            "style": "IPY_MODEL_06df3ad4b7454556a43b6d61640b12f8",
            "value": 21444401
          }
        },
        "f115ea4b5fad4bb1910fca49ed3da8a1": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_0bdc7325c839439589a16c88876d6bd5",
            "placeholder": "​",
            "style": "IPY_MODEL_873a483782894789bf0dee546a1b2d50",
            "value": " 20.5M/20.5M [00:00&lt;00:00, 61.8MB/s]"
          }
        },
        "e8eba8e353e940ff9287929e41e4d656": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "bc33539914a947ee89c271f10ea6a2bb": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "6e03cb60fab94b7e92ce16c8178922dd": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "5d464254c31d4516899643112fa0e958": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "06df3ad4b7454556a43b6d61640b12f8": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "ProgressStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "ProgressStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "bar_color": null,
            "description_width": ""
          }
        },
        "0bdc7325c839439589a16c88876d6bd5": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "873a483782894789bf0dee546a1b2d50": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "ae21171f17de45d895ab7a319dade609": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HBoxModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HBoxModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HBoxView",
            "box_style": "",
            "children": [
              "IPY_MODEL_f9c60d9c0aed49faa993fd865fb09174",
              "IPY_MODEL_755366e3f75e44c2b7a79bce78d77d11",
              "IPY_MODEL_4a05e8d965124327a2329cf9e1eec984"
            ],
            "layout": "IPY_MODEL_fe93ec079b384ac38a6f4d0e505431ff"
          }
        },
        "f9c60d9c0aed49faa993fd865fb09174": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_88dea77f1bcf44ffb69654515ee34f54",
            "placeholder": "​",
            "style": "IPY_MODEL_2678a567b0414e1d9cfbfc2ecf5ffd30",
            "value": "100%"
          }
        },
        "755366e3f75e44c2b7a79bce78d77d11": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "FloatProgressModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "FloatProgressModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "ProgressView",
            "bar_style": "success",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_ce621be138a84f33b24c05b2d9cfd5f0",
            "max": 5,
            "min": 0,
            "orientation": "horizontal",
            "style": "IPY_MODEL_1fa41d239a3a4845904434d057476a75",
            "value": 5
          }
        },
        "4a05e8d965124327a2329cf9e1eec984": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_f4827c6e36a1463fb0c82347f64230a2",
            "placeholder": "​",
            "style": "IPY_MODEL_cea8f9c48bd8429998352a090173f537",
            "value": " 5/5 [00:31&lt;00:00,  5.81s/it]"
          }
        },
        "fe93ec079b384ac38a6f4d0e505431ff": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "88dea77f1bcf44ffb69654515ee34f54": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "2678a567b0414e1d9cfbfc2ecf5ffd30": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "ce621be138a84f33b24c05b2d9cfd5f0": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "1fa41d239a3a4845904434d057476a75": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "ProgressStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "ProgressStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "bar_color": null,
            "description_width": ""
          }
        },
        "f4827c6e36a1463fb0c82347f64230a2": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "cea8f9c48bd8429998352a090173f537": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        }
      }
    }
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "view-in-github",
        "colab_type": "text"
      },
      "source": [
        "<a href=\"https://colab.research.google.com/github/mrdbourke/pytorch-deep-learning/blob/main/extras/exercises/06_pytorch_transfer_learning_exercises.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "# 06. PyTorch Transfer Learning Exercises\n",
        "\n",
        "Welcome to the 06. PyTorch Transfer Learning exercise template notebook.\n",
        "\n",
        "There are several questions in this notebook and it's your goal to answer them by writing Python and PyTorch code.\n",
        "\n",
        "> **Note:** There may be more than one solution to each of the exercises, don't worry too much about the *exact* right answer. Try to write some code that works first and then improve it if you can.\n",
        "\n",
        "## Resources and solutions\n",
        "\n",
        "* These exercises/solutions are based on [section 06. PyTorch Transfer Learning](https://www.learnpytorch.io/06_pytorch_transfer_learning/) of the Learn PyTorch for Deep Learning course by Zero to Mastery.\n",
        "\n",
        "**Solutions:** \n",
        "\n",
        "Try to complete the code below *before* looking at these.\n",
        "\n",
        "* See a live [walkthrough of the solutions (errors and all) on YouTube](https://youtu.be/ueLolShyFqs).\n",
        "* See an example [solutions notebook for these exercises on GitHub](https://github.com/mrdbourke/pytorch-deep-learning/blob/main/extras/solutions/06_pytorch_transfer_learning_exercise_solutions.ipynb)."
      ],
      "metadata": {
        "id": "zNqPNlYylluR"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "## 1. Make predictions on the entire test dataset and plot a confusion matrix for the results of our model compared to the truth labels. \n",
        "* **Note:** You will need the dataset and a trained model from notebook 06 (either reuse the trained model or retrain it here) to make predictions.\n",
        "* Check out [03. PyTorch Computer Vision section 10](https://www.learnpytorch.io/03_pytorch_computer_vision/#10-making-a-confusion-matrix-for-further-prediction-evaluation) for ideas."
      ],
      "metadata": {
        "id": "nwmoMhW8IqSu"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "# Import required libraries/code\n",
        "import torch\n",
        "import torchvision\n",
        "import numpy as np\n",
        "import matplotlib.pyplot as plt\n",
        "\n",
        "from torch import nn\n",
        "from torchvision import transforms, datasets\n",
        "\n",
        "# Try to get torchinfo, install it into the kernel's environment if it's missing\n",
        "# (only an import failure triggers the install; other errors are no longer swallowed)\n",
        "try:\n",
        "    from torchinfo import summary\n",
        "except ImportError:\n",
        "    print(\"[INFO] Couldn't find torchinfo... installing it.\")\n",
        "    %pip install -q torchinfo\n",
        "    from torchinfo import summary\n",
        "\n",
        "# Try to import the going_modular directory, download it from GitHub if it's missing\n",
        "try:\n",
        "    from going_modular.going_modular import data_setup, engine\n",
        "except ImportError:\n",
        "    # Get the going_modular scripts (shallow clone: only the latest snapshot is needed)\n",
        "    print(\"[INFO] Couldn't find going_modular scripts... downloading them from GitHub.\")\n",
        "    !git clone --depth 1 https://github.com/mrdbourke/pytorch-deep-learning\n",
        "    !mv pytorch-deep-learning/going_modular .\n",
        "    !rm -rf pytorch-deep-learning\n",
        "    from going_modular.going_modular import data_setup, engine"
      ],
      "metadata": {
        "id": "nqtAWBUJgaF1",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "14cc75f3-e109-4e84-c941-2598df3557b3"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "[INFO] Couldn't find torchinfo... installing it.\n",
            "[INFO] Couldn't find going_modular scripts... downloading them from GitHub.\n",
            "Cloning into 'pytorch-deep-learning'...\n",
            "remote: Enumerating objects: 1708, done.\u001b[K\n",
            "remote: Counting objects: 100% (160/160), done.\u001b[K\n",
            "remote: Compressing objects: 100% (88/88), done.\u001b[K\n",
            "remote: Total 1708 (delta 67), reused 151 (delta 60), pack-reused 1548\u001b[K\n",
            "Receiving objects: 100% (1708/1708), 230.85 MiB | 14.36 MiB/s, done.\n",
            "Resolving deltas: 100% (927/927), done.\n",
            "Checking out files: 100% (124/124), done.\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# Device-agnostic setup: use the GPU when CUDA is available, otherwise the CPU\n",
        "if torch.cuda.is_available():\n",
        "    device = \"cuda\"\n",
        "else:\n",
        "    device = \"cpu\"\n",
        "device"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 35
        },
        "id": "O10_T_xSKJlf",
        "outputId": "fd30e756-e542-4b2b-d974-1ed38451fecf"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "'cuda'"
            ],
            "application/vnd.google.colaboratory.intrinsic+json": {
              "type": "string"
            }
          },
          "metadata": {},
          "execution_count": 2
        }
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "### Get data "
      ],
      "metadata": {
        "id": "nrzg3TaSKLAh"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "import os\n",
        "import requests\n",
        "import zipfile\n",
        "\n",
        "from pathlib import Path\n",
        "\n",
        "# Setup paths to the data folder, the extracted images and the temporary archive\n",
        "data_path = Path(\"data/\")\n",
        "image_path = data_path / \"pizza_steak_sushi\"\n",
        "zip_path = data_path / \"pizza_steak_sushi.zip\"\n",
        "data_url = \"https://github.com/mrdbourke/pytorch-deep-learning/raw/main/data/pizza_steak_sushi.zip\"\n",
        "\n",
        "# If the image folder doesn't exist, download it and prepare it...\n",
        "if image_path.is_dir():\n",
        "    print(f\"{image_path} directory exists.\")\n",
        "else:\n",
        "    print(f\"Did not find {image_path} directory, creating one...\")\n",
        "    # Only the parent folder is created up front: image_path is created after the archive\n",
        "    # has been downloaded and opened, so a failed download is retried on the next run\n",
        "    data_path.mkdir(parents=True, exist_ok=True)\n",
        "\n",
        "    # Download pizza, steak, sushi data (raise on HTTP errors rather than saving an error page)\n",
        "    print(\"Downloading pizza, steak, sushi data...\")\n",
        "    response = requests.get(data_url, timeout=60)\n",
        "    response.raise_for_status()\n",
        "    zip_path.write_bytes(response.content)\n",
        "\n",
        "    # Unzip pizza, steak, sushi data\n",
        "    with zipfile.ZipFile(zip_path, \"r\") as zip_ref:\n",
        "        print(\"Unzipping pizza, steak, sushi data...\")\n",
        "        image_path.mkdir(parents=True, exist_ok=True)\n",
        "        zip_ref.extractall(image_path)\n",
        "\n",
        "    # Remove .zip file\n",
        "    os.remove(zip_path)\n",
        "\n",
        "# Setup Dirs\n",
        "train_dir = image_path / \"train\"\n",
        "test_dir = image_path / \"test\""
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "Lt_CNQ4rKPmg",
        "outputId": "a1364d91-3afa-4401-94cb-94e4df837f06"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Did not find data/pizza_steak_sushi directory, creating one...\n",
            "Downloading pizza, steak, sushi data...\n",
            "Unzipping pizza, steak, sushi data...\n"
          ]
        }
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "### Prepare data"
      ],
      "metadata": {
        "id": "PGaMWWaoKQlM"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "# Create a transforms pipeline\n",
        "simple_transform = transforms.Compose([\n",
        "    # 1. Resize all images to 224x224 (though some models may require different sizes)\n",
        "    transforms.Resize((224, 224)),\n",
        "    # 2. Turn image values to between 0 & 1\n",
        "    transforms.ToTensor(),\n",
        "    # 3. Normalize each colour channel with a mean of [0.485, 0.456, 0.406]\n",
        "    #    and a standard deviation of [0.229, 0.224, 0.225]\n",
        "    transforms.Normalize(\n",
        "        mean=[0.485, 0.456, 0.406],\n",
        "        std=[0.229, 0.224, 0.225],\n",
        "    ),\n",
        "])"
      ],
      "metadata": {
        "id": "VNIQNEQVKVXu"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Build the train/test DataLoaders and collect the list of class names\n",
        "train_dataloader, test_dataloader, class_names = data_setup.create_dataloaders(\n",
        "    train_dir=train_dir,\n",
        "    test_dir=test_dir,\n",
        "    transform=simple_transform,  # resize, convert images to between 0 & 1 and normalize them\n",
        "    batch_size=32,  # set mini-batch size to 32\n",
        ")\n",
        "\n",
        "train_dataloader, test_dataloader, class_names"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "Njd5lHTcKW23",
        "outputId": "fbc224df-8243-4e7b-90cd-49adc000fd47"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "(<torch.utils.data.dataloader.DataLoader at 0x7f5f520c9bd0>,\n",
              " <torch.utils.data.dataloader.DataLoader at 0x7f5f520c9c90>,\n",
              " ['pizza', 'steak', 'sushi'])"
            ]
          },
          "metadata": {},
          "execution_count": 5
        }
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "### Get and prepare a pretrained model"
      ],
      "metadata": {
        "id": "Ciw2DiRHKaSE"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "# Setup the model with pretrained weights...\n",
        "# NOTE(review): newer torchvision releases deprecate `pretrained=True` in favour of `weights=...` - confirm against the installed version\n",
        "model_0 = torchvision.models.efficientnet_b0(pretrained=True)\n",
        "# ...and send it to the target device\n",
        "model_0 = model_0.to(device)\n",
        "#model_0 # uncomment to output (it's very long)"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 86,
          "referenced_widgets": [
            "6e25b4bb0d254191a793696a0f4f00ce",
            "37424313e66f474da42cfe1b512f09df",
            "58fd00f6a9114192a4fa757c1f669bff",
            "f115ea4b5fad4bb1910fca49ed3da8a1",
            "e8eba8e353e940ff9287929e41e4d656",
            "bc33539914a947ee89c271f10ea6a2bb",
            "6e03cb60fab94b7e92ce16c8178922dd",
            "5d464254c31d4516899643112fa0e958",
            "06df3ad4b7454556a43b6d61640b12f8",
            "0bdc7325c839439589a16c88876d6bd5",
            "873a483782894789bf0dee546a1b2d50"
          ]
        },
        "id": "snUuRXd8Kdk5",
        "outputId": "eac2a1e6-5607-437e-90b5-41639d17e5a8"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "Downloading: \"https://download.pytorch.org/models/efficientnet_b0_rwightman-3dd342df.pth\" to /root/.cache/torch/hub/checkpoints/efficientnet_b0_rwightman-3dd342df.pth\n"
          ]
        },
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "  0%|          | 0.00/20.5M [00:00<?, ?B/s]"
            ],
            "application/vnd.jupyter.widget-view+json": {
              "version_major": 2,
              "version_minor": 0,
              "model_id": "6e25b4bb0d254191a793696a0f4f00ce"
            }
          },
          "metadata": {}
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# Freeze all base layers in the \"features\" section of the model (the feature extractor).\n",
        "# Module.requires_grad_(False) sets requires_grad=False on every parameter of the module;\n",
        "# the trailing semicolon stops the notebook from printing the returned module.\n",
        "model_0.features.requires_grad_(False);"
      ],
      "metadata": {
        "id": "IbRhGvy_KeVL"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Set the manual seeds\n",
        "torch.manual_seed(42)\n",
        "torch.cuda.manual_seed(42)\n",
        "\n",
        "# One output unit for each class, so take the length of class_names\n",
        "output_shape = len(class_names)\n",
        "\n",
        "# Recreate the classifier layer and send it to the target device\n",
        "model_0.classifier = torch.nn.Sequential(\n",
        "    torch.nn.Dropout(p=0.2, inplace=True),\n",
        "    torch.nn.Linear(\n",
        "        in_features=1280,\n",
        "        out_features=output_shape,  # same number of output units as our number of classes\n",
        "        bias=True,\n",
        "    ),\n",
        ").to(device)"
      ],
      "metadata": {
        "id": "G1-6xV3ZKeSX"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "### Train model"
      ],
      "metadata": {
        "id": "XQFaXX8CKePi"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "# Define the loss function\n",
        "loss_fn = nn.CrossEntropyLoss()\n",
        "\n",
        "# Define the optimizer (it is handed every parameter of model_0, including the ones set to requires_grad=False above)\n",
        "optimizer = torch.optim.Adam(params=model_0.parameters(), lr=1e-3)"
      ],
      "metadata": {
        "id": "exxU79eaKeM6"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "from timeit import default_timer as timer\n",
        "\n",
        "# Set the random seeds\n",
        "torch.manual_seed(42)\n",
        "torch.cuda.manual_seed(42)\n",
        "\n",
        "# Start the timer\n",
        "start_time = timer()\n",
        "\n",
        "# Train model_0 for 5 epochs and keep whatever engine.train() returns\n",
        "model_0_results = engine.train(\n",
        "    model=model_0,\n",
        "    train_dataloader=train_dataloader,\n",
        "    test_dataloader=test_dataloader,\n",
        "    optimizer=optimizer,\n",
        "    loss_fn=loss_fn,\n",
        "    epochs=5,\n",
        "    device=device,\n",
        ")\n",
        "\n",
        "# End the timer and print out how long it took\n",
        "end_time = timer()\n",
        "print(f\"[INFO] Total training time: {end_time-start_time:.3f} seconds\")"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 153,
          "referenced_widgets": [
            "ae21171f17de45d895ab7a319dade609",
            "f9c60d9c0aed49faa993fd865fb09174",
            "755366e3f75e44c2b7a79bce78d77d11",
            "4a05e8d965124327a2329cf9e1eec984",
            "fe93ec079b384ac38a6f4d0e505431ff",
            "88dea77f1bcf44ffb69654515ee34f54",
            "2678a567b0414e1d9cfbfc2ecf5ffd30",
            "ce621be138a84f33b24c05b2d9cfd5f0",
            "1fa41d239a3a4845904434d057476a75",
            "f4827c6e36a1463fb0c82347f64230a2",
            "cea8f9c48bd8429998352a090173f537"
          ]
        },
        "id": "ComVkVtuKeKG",
        "outputId": "6d43205a-4e9f-4627-999a-40d07380cd58"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "  0%|          | 0/5 [00:00<?, ?it/s]"
            ],
            "application/vnd.jupyter.widget-view+json": {
              "version_major": 2,
              "version_minor": 0,
              "model_id": "ae21171f17de45d895ab7a319dade609"
            }
          },
          "metadata": {}
        },
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Epoch: 1 | train_loss: 1.0894 | train_acc: 0.4492 | test_loss: 0.9214 | test_acc: 0.5085\n",
            "Epoch: 2 | train_loss: 0.8697 | train_acc: 0.7734 | test_loss: 0.8036 | test_acc: 0.7434\n",
            "Epoch: 3 | train_loss: 0.7769 | train_acc: 0.7734 | test_loss: 0.7404 | test_acc: 0.7737\n",
            "Epoch: 4 | train_loss: 0.7244 | train_acc: 0.7422 | test_loss: 0.6488 | test_acc: 0.8864\n",
            "Epoch: 5 | train_loss: 0.6426 | train_acc: 0.7812 | test_loss: 0.6254 | test_acc: 0.8968\n",
            "[INFO] Total training time: 31.032 seconds\n"
          ]
        }
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "### Make predictions on the entire test dataset with the model"
      ],
      "metadata": {
        "id": "xFS4lE_IKyE_"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "# TODO"
      ],
      "metadata": {
        "id": "DwZuCluFu375"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "### Make a confusion matrix with the test preds and the truth labels"
      ],
      "metadata": {
        "id": "Mb2bQ1b5K2WP"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "Need the following libraries to make a confusion matrix:\n",
        "* torchmetrics - https://torchmetrics.readthedocs.io/en/stable/\n",
        "* mlxtend - http://rasbt.github.io/mlxtend/"
      ],
      "metadata": {
        "id": "5I2jpYAcM07s"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "# See if torchmetrics and a recent enough mlxtend exist, if not, install them\n",
        "try:\n",
        "    import torchmetrics, mlxtend\n",
        "    print(f\"mlxtend version: {mlxtend.__version__}\")\n",
        "    # Compare (major, minor) as a tuple so a future 1.x release still passes the check\n",
        "    assert tuple(int(part) for part in mlxtend.__version__.split(\".\")[:2]) >= (0, 19), \"mlxtend version should be 0.19.0 or higher\"\n",
        "except (ImportError, AssertionError):\n",
        "    # Only a missing package or a too-old mlxtend triggers the install; any other error is left to surface\n",
        "    !pip install -q torchmetrics -U mlxtend # <- Note: If you're using Google Colab, this may require restarting the runtime\n",
        "    import torchmetrics, mlxtend\n",
        "    print(f\"mlxtend version: {mlxtend.__version__}\")"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "qcKYZGWuK2P8",
        "outputId": "88c33b26-0b76-42d7-8a27-fb3073b1fc3f"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "\u001b[K     |████████████████████████████████| 409 kB 7.5 MB/s \n",
            "\u001b[K     |████████████████████████████████| 1.3 MB 45.7 MB/s \n",
            "\u001b[?25hmlxtend version: 0.19.0\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# Import mlxtend upgraded version\n",
        "import mlxtend\n",
        "print(mlxtend.__version__)\n",
        "# Should be version 0.19.0 or higher; (major, minor) is compared as a tuple so a 1.x release also passes\n",
        "assert tuple(int(part) for part in mlxtend.__version__.split(\".\")[:2]) >= (0, 19), \"mlxtend version should be 0.19.0 or higher\""
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "QOYVew4xMxgI",
        "outputId": "d3b393b8-09c3-46f7-c799-2f91ee4d30e6"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "0.19.0\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# TODO"
      ],
      "metadata": {
        "id": "_5LU9-5Xu7dP"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "## 2. Get the \"most wrong\" of the predictions on the test dataset and plot the 5 \"most wrong\" images. You can do this by:\n",
        "* Predicting across all of the test dataset, storing the labels and predicted probabilities.\n",
        "* Sort the predictions by *wrong prediction* and then *descending predicted probabilities*, this will give you the wrong predictions with the *highest* prediction probabilities, in other words, the \"most wrong\".\n",
        "* Plot the top 5 \"most wrong\" images, why do you think the model got these wrong?\n",
        "\n",
        "You'll want to:\n",
        "* Create a DataFrame with sample, label, prediction, pred prob\n",
        "* Sort DataFrame by correct (does label == prediction)\n",
        "* Sort DataFrame by pred prob (descending)\n",
        "* Plot the top 5 \"most wrong\" image predictions"
      ],
      "metadata": {
        "id": "YqlStPo-gbrF"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "# TODO"
      ],
      "metadata": {
        "id": "cHtMeYHuvDwy"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "## 3. Predict on your own image of pizza/steak/sushi - how does the model go? What happens if you predict on an image that isn't pizza/steak/sushi?\n",
        "* Here you can get an image from a website like http://www.unsplash.com to try it out or you can upload your own."
      ],
      "metadata": {
        "id": "1IvuTskxgjaw"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "# TODO: Get an image of pizza/steak/sushi\n"
      ],
      "metadata": {
        "id": "C16glgVFglmG"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# TODO: Get an image of not pizza/steak/sushi\n"
      ],
      "metadata": {
        "id": "clA_KmihVYyA"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "## 4. Train the model from section 4 in notebook 06 part 3 for longer (10 epochs should do). What happens to the performance?\n",
        "\n",
        "* See the model in notebook 06 part 3 for reference: https://www.learnpytorch.io/06_pytorch_transfer_learning/#3-getting-a-pretrained-model"
      ],
      "metadata": {
        "id": "Vzvi8GprgmJ0"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "# TODO: Recreate a new model \n"
      ],
      "metadata": {
        "id": "kIKg53Jna-Rt"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# TODO: Train the model for 10 epochs"
      ],
      "metadata": {
        "id": "JhGT9igPgoF5"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "## 5. Train the model from section 4 above with more data, say 20% of the images from Food101 of Pizza, Steak and Sushi images.\n",
        "* You can find the [20% Pizza, Steak, Sushi dataset](https://github.com/mrdbourke/pytorch-deep-learning/blob/main/data/pizza_steak_sushi_20_percent.zip) on the course GitHub. It was created with the notebook [`extras/04_custom_data_creation.ipynb`](https://github.com/mrdbourke/pytorch-deep-learning/blob/main/extras/04_custom_data_creation.ipynb). \n"
      ],
      "metadata": {
        "id": "_oRrWPZTgoqL"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "### Get 20% data"
      ],
      "metadata": {
        "id": "VxyMMnUbgvw2"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "import os\n",
        "import requests\n",
        "import zipfile\n",
        "\n",
        "from pathlib import Path\n",
        "\n",
        "# Setup path to data folder\n",
        "data_path = Path(\"data/\")\n",
        "image_path = data_path / \"pizza_steak_sushi_20_percent\"\n",
        "image_data_zip_path = \"pizza_steak_sushi_20_percent.zip\"\n",
        "\n",
        "# If the image folder doesn't exist, download it and prepare it...\n",
        "if image_path.is_dir():\n",
        "    print(f\"{image_path} directory exists.\")\n",
        "else:\n",
        "    print(f\"Did not find {image_path} directory, creating one...\")\n",
        "    # Only the parent folder is created up front. image_path itself is created after the\n",
        "    # archive has been downloaded and opened, so a failed download can't leave behind an\n",
        "    # empty folder that would make the next run skip the download.\n",
        "    data_path.mkdir(parents=True, exist_ok=True)\n",
        "\n",
        "    # Download pizza, steak, sushi data\n",
        "    print(\"Downloading pizza, steak, sushi data...\")\n",
        "    request = requests.get(\"https://github.com/mrdbourke/pytorch-deep-learning/raw/main/data/pizza_steak_sushi_20_percent.zip\")\n",
        "    request.raise_for_status()  # stop here on an HTTP error instead of saving the error page as the .zip\n",
        "    with open(data_path / image_data_zip_path, \"wb\") as f:\n",
        "        f.write(request.content)\n",
        "\n",
        "    # Unzip pizza, steak, sushi data\n",
        "    with zipfile.ZipFile(data_path / image_data_zip_path, \"r\") as zip_ref:\n",
        "        print(\"Unzipping pizza, steak, sushi 20% data...\")\n",
        "        image_path.mkdir(parents=True, exist_ok=True)\n",
        "        zip_ref.extractall(image_path)\n",
        "\n",
        "    # Remove .zip file\n",
        "    os.remove(data_path / image_data_zip_path)\n",
        "\n",
        "# Setup Dirs\n",
        "train_dir_20_percent = image_path / \"train\"\n",
        "test_dir_20_percent = image_path / \"test\"\n",
        "\n",
        "train_dir_20_percent, test_dir_20_percent"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "U_fdu5m2eKT9",
        "outputId": "121c61f3-f505-4302-b3b9-8b8bae5b5e1d"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Did not find data/pizza_steak_sushi_20_percent directory, creating one...\n",
            "Downloading pizza, steak, sushi data...\n",
            "Unzipping pizza, steak, sushi 20% data...\n"
          ]
        },
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "(PosixPath('data/pizza_steak_sushi_20_percent/train'),\n",
              " PosixPath('data/pizza_steak_sushi_20_percent/test'))"
            ]
          },
          "metadata": {},
          "execution_count": 33
        }
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "### Create DataLoaders"
      ],
      "metadata": {
        "id": "SQj7eFdSe4Fv"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "# Create a transforms pipeline\n",
        "simple_transform = transforms.Compose([\n",
        "    # 1. Resize all images to 224x224 (though some models may require different sizes)\n",
        "    transforms.Resize((224, 224)),\n",
        "    # 2. Turn image values to between 0 & 1\n",
        "    transforms.ToTensor(),\n",
        "    # 3. Normalize each colour channel with a mean of [0.485, 0.456, 0.406]\n",
        "    #    and a standard deviation of [0.229, 0.224, 0.225]\n",
        "    transforms.Normalize(\n",
        "        mean=[0.485, 0.456, 0.406],\n",
        "        std=[0.229, 0.224, 0.225],\n",
        "    ),\n",
        "])"
      ],
      "metadata": {
        "id": "TEG_k785e7Jw"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Build the 20% train/test DataLoaders and collect the list of class names\n",
        "train_dataloader_20_percent, test_dataloader_20_percent, class_names = data_setup.create_dataloaders(\n",
        "    train_dir=train_dir_20_percent,\n",
        "    test_dir=test_dir_20_percent,\n",
        "    transform=simple_transform,  # resize, convert images to between 0 & 1 and normalize them\n",
        "    batch_size=32,  # set mini-batch size to 32\n",
        ")\n",
        "\n",
        "train_dataloader_20_percent, test_dataloader_20_percent, class_names"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "82x7LnQJe7H5",
        "outputId": "342fd4e7-0656-495a-aee0-0d23be130438"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "(<torch.utils.data.dataloader.DataLoader at 0x7f5ede28e390>,\n",
              " <torch.utils.data.dataloader.DataLoader at 0x7f5ede28e210>,\n",
              " ['pizza', 'steak', 'sushi'])"
            ]
          },
          "metadata": {},
          "execution_count": 35
        }
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "### Get a pretrained model"
      ],
      "metadata": {
        "id": "qROl77sKfIOd"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "# TODO"
      ],
      "metadata": {
        "id": "PHWNZ6yDvpR8"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "### Train a model with 20% of the data"
      ],
      "metadata": {
        "id": "UqffJfOIfp3T"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "# TODO"
      ],
      "metadata": {
        "id": "wXpYOYeTvp7a"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "## 6. Try a different model from [`torchvision.models`](https://pytorch.org/vision/stable/models.html) on the Pizza, Steak, Sushi data, how does this model perform?\n",
        "* You'll have to change the size of the classifier layer to suit our problem.\n",
        "* You may want to try an EfficientNet with a higher number than our B0, perhaps `torchvision.models.efficientnet_b2()`?\n",
        "  * **Note:** Depending on the model you use you will have to prepare/transform the data in a certain way."
      ],
      "metadata": {
        "id": "Ibj4UPjRgvly"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "# TODO "
      ],
      "metadata": {
        "id": "3FQ8tL7El7eO"
      },
      "execution_count": null,
      "outputs": []
    }
  ]
}